import torch
from sklearn.metrics import roc_auc_score

def _group_rate(pred, group):
    """Return the mean of ``pred`` over the positions selected by ``group``.

    ``group`` is a boolean mask aligned with ``pred``. When it selects no
    positions the rate is undefined; 0.0 is returned instead of raising, so a
    split that lacks one sensitive group still yields a number.
    """
    count = group.sum().item()
    if count == 0:
        return 0.0
    return pred[group].sum().item() / count


@torch.no_grad()
def test(model, x, data, sens, logits=None, evaluator=None, device="cuda", use_sens_as_labels=False):
    """Evaluate accuracy and two group-fairness gaps on each data split.

    Splits are taken from ``data('train_mask', 'val_mask', 'test_mask')``,
    which is expected to yield ``(name, boolean_mask)`` pairs -- presumably a
    PyG ``Data`` object; TODO confirm it yields all three masks in this order,
    since the returned lists are positional.

    For every split the predicted class is the argmax of ``logits``. Then:

    * ``acc`` -- fraction of masked nodes whose prediction equals the label.
    * ``dp``  -- |mean(pred | sens) - mean(pred | ~sens)| over masked nodes
      (a statistical-parity gap; it is a positive-rate gap only when the
      predictions are binary 0/1, which is assumed -- TODO confirm).
    * ``eo``  -- the same gap restricted to nodes whose label is truthy
      (an equal-opportunity gap).

    Any group-conditional mean over an empty group is taken as 0.0.

    Args:
        model: Called via ``inference_full_batch`` only when ``logits`` is
            None; in that case it is switched to eval mode and left there.
        x: Node features forwarded to ``model``.
        data: Graph container; reads ``y``, ``edge_index``, ``edge_attr``
            (the latter two only when ``logits`` is None) and is called to
            iterate the split masks.
        sens: Per-node sensitive attribute; coerced to ``bool``.
        logits: Precomputed model outputs to reuse; computed when None.
        evaluator: Unused; kept for call compatibility.
        device: Unused; kept for call compatibility.
        use_sens_as_labels: If True, score predictions against ``sens``
            instead of ``data.y``.

    Returns:
        ``(accs, dps, eos, logits)`` -- three lists with one float per split
        yielded by ``data(...)``, plus the logits used.
    """
    if logits is None:
        model.eval()
        logits = inference_full_batch(model, x, data.edge_index, data.edge_attr.reshape(-1, 1))

    # ``~`` and boolean indexing below require a bool tensor; with 0/1 ints
    # ``~`` would be bitwise and indexing would gather instead of mask.
    sens = sens.bool()

    accs, dps, eos = [], [], []
    for _, mask in data('train_mask', 'val_mask', 'test_mask'):
        pred = logits[mask].argmax(dim=1)
        labels = sens[mask] if use_sens_as_labels else data.y[mask]
        # Flatten once so (n, 1)-shaped targets cannot broadcast against the
        # (n,)-shaped group masks into an (n, n) mask.
        labels = labels.reshape(-1)

        in_group = sens[mask]
        positive = labels.bool()

        acc = pred.eq(labels).sum().item() / mask.sum().item()
        dp = abs(_group_rate(pred, in_group) - _group_rate(pred, ~in_group))
        eo = abs(_group_rate(pred, in_group & positive) - _group_rate(pred, ~in_group & positive))

        accs.append(acc)
        dps.append(dp)
        eos.append(eo)
    return accs, dps, eos, logits


# The public name starts with ``test``; without this, pytest collects it as a
# test (and fails on missing fixtures) whenever a test module imports it.
test.__test__ = False

def inference_full_batch(model, x, edge_index, edge_weight):
    """Run ``model`` once over the full graph and return its raw output.

    The inputs are forwarded positionally as
    ``model(x, edge_index, edge_weight)``; this helper does not change the
    model's train/eval mode or gradient tracking.
    """
    return model(x, edge_index, edge_weight)
